import json
import os

import numpy as np
import torch

from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import plot_lines
from ModularUtils.Functions_Plot_Results import plot_saved_results
from Train_By_Components.Causal_TrainGraph import set_trainGraph

# Width, in recorded entries, of each averaging window applied to the curves.
delta = 3
# Upper bound on how many recorded entries per curve are averaged.
epochs = 400

# NOTE(review): `path` is neither defined nor imported anywhere in this file,
# so this line raises NameError as written. It presumably stands for the
# directory of the run whose curves are labelled "mod" -- TODO set it.
last_exp = path
kl_diff = {}
tvd_diff = {}
dashed = []
solid = []
# Maps the file name of each saved curve to the legend label used for it.
# Raw strings keep the LaTeX `\_` literally; in a normal string literal `\_`
# is an invalid escape sequence and triggers a warning on recent Pythons.
dist_keys = {
    'P(X0,W1,W0,Y0|do(X1X2))': r'$mod\_P(X_0,W_1,W_0,Y_0|do(X_1,X_2))$',
    'P(X1,X2,W1,Y1|do(X0))': r'$mod\_P(X_1,X_2,W_1,Y_1|do(X_0))$',
    'P(X0,X1,X2,W1,Y1,W0,Y0)': r'$mod\_P(V)$',
}

print("tvd diffs")
for dist, label in dist_keys.items():
    # Each file is loaded from the "tvd" / "kl" sub-directory of `last_exp`,
    # converted to a NumPy array, and truncated to its first 150 entries.
    tvd_diff[label] = torch.load(f"{last_exp}/tvd/{dist}").detach().cpu().numpy()[0:150]
    kl_diff[label] = torch.load(f"{last_exp}/kl/{dist}").detach().cpu().numpy()[0:150]



# Directory holding the second set of saved curves.
# NOTE(review): "last_exp_path" looks like a placeholder -- confirm the real
# location before running.
last_exp = "last_exp_path"
# Same file names as the first set, but with "full" legend labels. Raw strings
# keep the LaTeX `\_` literally; in a normal string literal `\_` is an invalid
# escape sequence and triggers a warning on recent Pythons.
new_dist = {
    'P(X0,W1,W0,Y0|do(X1X2))': r'$full\_P(X_0,W_1,W_0,Y_0|do(X_1,X_2))$',
    'P(X1,X2,W1,Y1|do(X0))': r'$full\_P(X_1,X_2,W_1,Y_1|do(X_0))$',
    'P(X0,X1,X2,W1,Y1,W0,Y0)': r'$full\_P(V)$',
}

# The "full" labels are the ones handed to plot_lines as the dashed series.
dashed = list(new_dist.values())

# new_dist uses the same keys as dist_keys, so this replaces the earlier
# labels in dist_keys with the "full" ones.
dist_keys.update(new_dist)
# Kept for parity with `dashed`; nothing later in this file reads `solid`.
solid = new_dist.values()
print("tvd diffs")
for dist, label in new_dist.items():
    # Each file is loaded from the "tvd" / "kl" sub-directory of `last_exp`,
    # converted to a NumPy array, and truncated to its first 150 entries.
    tvd_diff[label] = torch.load(f"{last_exp}/tvd/{dist}").detach().cpu().numpy()[0:150]
    kl_diff[label] = torch.load(f"{last_exp}/kl/{dist}").detach().cpu().numpy()[0:150]



def _window_means_and_spreads(series, width, limit):
    """Summarise ``series`` over consecutive windows of ``width`` entries.

    Only complete windows ending at or before index ``limit`` are used
    (``width`` is expected to be positive).

    Returns two lists with one entry per window: the window mean, and the
    mean absolute deviation of the window's entries from that mean.
    """
    means, spreads = [], []
    for window_no in range(limit // width):
        window = series[window_no * width:(window_no + 1) * width]
        center = np.mean(window)
        means.append(center)
        spreads.append(np.mean(abs(window - center)))
    return means, spreads


label_keys = tvd_diff.keys()

tvd_error, kl_error = {}, {}
new_tvd, new_kl = {}, {}
xaxis = []
for label, tvd_series in tvd_diff.items():
    # The TVD curve's length (capped at `epochs`) bounds the windows used for
    # both the TVD and the KL curve of this label.
    usable = min(epochs, tvd_series.shape[0])
    new_tvd[label], tvd_error[label] = _window_means_and_spreads(tvd_series, delta, usable)
    new_kl[label], kl_error[label] = _window_means_and_spreads(kl_diff[label], delta, usable)

    # Start index of every window; the value from the last label processed is
    # the one that remains after the loop.
    xaxis = [step * delta for step in range(len(new_tvd[label]))]

# dashed=list(new_dist.values())

# One figure per metric: windowed means as the curves, windowed mean absolute
# deviations as the accompanying error values.
label_keys = tvd_diff.keys()
for y_label, means_by_label, spread_by_label in (
    ("Total Variation Distance", new_tvd, tvd_error),
    ("KL Divergence", new_kl, kl_error),
):
    plot_lines(
        "Modular Training distribution convergence",
        y_label,
        list(means_by_label.values()),
        xaxis,
        list(label_keys),
        dashed,
        [],
        list(spread_by_label.values()),
        save_plot=False,
        path=last_exp,
    )